{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    \"\"\"GRU encoder: token ids -> embeddings -> GRU.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of tokens in the source vocabulary.\n",
    "        hidden_size: size of both the embeddings and the GRU hidden state.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(vocab_size, hidden_size)\n",
    "        # batch_first=True: inputs/outputs are batch * seq_len * features\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"Encode `inputs`, a batch * seq_len tensor of token ids.\n",
    "\n",
    "        Because of batch_first=True the input needs a batch dimension even\n",
    "        for a single sentence (batch = 1).\n",
    "\n",
    "        Returns:\n",
    "            (output, hidden): the GRU output at every step\n",
    "            (batch * seq_len * hidden_size) and the final hidden state\n",
    "            (1 * batch * hidden_size).\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(inputs)\n",
    "        return self.gru(embedded)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "class RNNDecoder(nn.Module):\n",
    "    \"\"\"Plain GRU decoder without attention.\n",
    "\n",
    "    The encoder's final hidden state is used as the decoder's initial\n",
    "    hidden state. forward() reads the notebook globals SOS_token and\n",
    "    MAX_LENGTH, which are defined in later cells and must exist before\n",
    "    the decoder is called.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # Source and target may be different languages, so the decoder\n",
    "        # owns a separate embedding table\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\n",
    "            batch_first=True)\n",
    "        # Projects GRU outputs onto unnormalized scores over the vocabulary\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        \"\"\"Decode a whole target sequence.\n",
    "\n",
    "        Args:\n",
    "            encoder_outputs: batch * src_len * hidden_size; only its batch\n",
    "                size and device are used here.\n",
    "            encoder_hidden: 1 * batch * hidden_size initial decoder state.\n",
    "            target_tensor: optional batch * tgt_len gold ids. If given,\n",
    "                decode tgt_len steps with teacher forcing; otherwise decode\n",
    "                MAX_LENGTH steps greedily.\n",
    "\n",
    "        Returns:\n",
    "            (log_probs, hidden, None): log_probs is\n",
    "            batch * steps * vocab_size; the trailing None keeps the\n",
    "            interface aligned with AttnRNNDecoder.\n",
    "        \"\"\"\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start from <sos>. Create it on the encoder outputs' device so the\n",
    "        # decoder also works when the model and data live on a GPU.\n",
    "        decoder_input = torch.full((batch_size, 1), SOS_token,\n",
    "            dtype=torch.long, device=encoder_outputs.device)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "\n",
    "        # A known target fixes the number of steps; otherwise decode up to\n",
    "        # the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "\n",
    "        for i in range(seq_length):\n",
    "            # One step: one token plus the current hidden state\n",
    "            decoder_output, decoder_hidden = self.forward_step(\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the gold token as the next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Greedy: feed the highest-scoring token as the next input;\n",
    "                # detach so no gradient flows back through this choice\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    def forward_step(self, input, hidden):\n",
    "        \"\"\"Run a single decoding step.\n",
    "\n",
    "        input: batch * 1 token ids; hidden: 1 * batch * hidden_size.\n",
    "        Returns (scores of shape batch * 1 * vocab_size, new hidden state).\n",
    "        \"\"\"\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器当前的隐状态（作为查询）和编码器每一步的输出（作为键和值）。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用广泛的一种注意力机制，在机器翻译任务中尤为常见。该注意力机制使用一个对齐模型（alignment model）来计算解码器隐状态与编码器各步输出之间的注意力分数，具体来讲就是一个前馈神经网络。与点积注意力相比，Bahdanau注意力（又称加性注意力）在计算分数时引入了非线性变换（tanh）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    \"\"\"Additive (Bahdanau) attention.\n",
    "\n",
    "    score(q, k) = Va(tanh(Wa(q) + Ua(k))); the softmax-normalized scores\n",
    "    weight the keys to form a context vector.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        \"\"\"query: batch * 1 * hidden_size; keys: batch * seq_length * hidden_size.\n",
    "\n",
    "        Returns (context, weights) with shapes batch * 1 * hidden_size and\n",
    "        batch * 1 * seq_length.\n",
    "        \"\"\"\n",
    "        # Wa(query) is broadcast along the seq_length axis of Ua(keys)\n",
    "        energy = torch.tanh(self.Wa(query) + self.Ua(keys))\n",
    "        # batch * seq_length * 1  ->  batch * 1 * seq_length\n",
    "        scores = self.Va(energy).transpose(1, 2)\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = weights.bmm(keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    \"\"\"GRU decoder that attends over the encoder outputs at every step.\n",
    "\n",
    "    forward() reads the notebook globals SOS_token and MAX_LENGTH, which\n",
    "    are defined in later cells and must exist before the decoder is called.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # GRU input is [token embedding; context vector], so 2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\n",
    "            batch_first=True)\n",
    "        # Projects GRU outputs onto unnormalized scores over the vocabulary\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        \"\"\"Decode a whole target sequence.\n",
    "\n",
    "        Args:\n",
    "            encoder_outputs: batch * src_len * hidden_size, attended at\n",
    "                every step.\n",
    "            encoder_hidden: 1 * batch * hidden_size initial decoder state.\n",
    "            target_tensor: optional batch * tgt_len gold ids. If given,\n",
    "                decode tgt_len steps with teacher forcing; otherwise decode\n",
    "                MAX_LENGTH steps greedily.\n",
    "\n",
    "        Returns:\n",
    "            (log_probs, hidden, attentions): log_probs is\n",
    "            batch * steps * vocab_size and attentions is\n",
    "            batch * steps * src_len.\n",
    "        \"\"\"\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start from <sos>. Create it on the encoder outputs' device so the\n",
    "        # decoder also works when the model and data live on a GPU.\n",
    "        decoder_input = torch.full((batch_size, 1), SOS_token,\n",
    "            dtype=torch.long, device=encoder_outputs.device)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # A known target fixes the number of steps; otherwise decode up to\n",
    "        # the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "\n",
    "        for i in range(seq_length):\n",
    "            # One step: one token, the current hidden state, and the\n",
    "            # encoder outputs to attend over\n",
    "            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the gold token as the next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Greedy: feed the highest-scoring token as the next input;\n",
    "                # detach so no gradient flows back through this choice\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        \"\"\"Run a single decoding step.\n",
    "\n",
    "        input: batch * 1 token ids; hidden: 1 * batch * hidden_size.\n",
    "        Returns (scores of shape batch * 1 * vocab_size, new hidden state,\n",
    "        attention weights of shape batch * 1 * src_len).\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(input)\n",
    "        # The GRU state is 1 * batch * hidden_size, but the attention query\n",
    "        # must be batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embedded, context), dim=2)\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    \"\"\"Single-layer Transformer encoder that handles one sentence at a time.\"\"\"\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\n",
    "            hidden_size)\n",
    "        # Reuse TransformerLayer as the encoder layer; one layer for simplicity\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\n",
    "            dropout, intermediate_size)\n",
    "        # Unlike TransformerLM, the encoder needs no output projection\n",
    "\n",
    "    def forward(self, input_ids):\n",
    "        \"\"\"Encode a single sentence.\n",
    "\n",
    "        input_ids: 1 * seq_len token ids, seq_len <= embedding_layer.max_len.\n",
    "        Returns (hidden_states, attention_mask), where attention_mask is a\n",
    "        1 * seq_len tensor of ones. Supporting batches would require\n",
    "        building per-sentence masks from the input lengths.\n",
    "        \"\"\"\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "\n",
    "        # Build position ids and mask on the input's device so the encoder\n",
    "        # also works when the model and data live on a GPU\n",
    "        device = input_ids.device\n",
    "        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32,\n",
    "            device=device)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5e8cfc9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"Multi-head attention from target queries to source keys/values.\n",
    "\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        Returns batch_size * tgt_seq_len * hidden_size.\n",
    "\n",
    "        The projections (W_q, W_k, W_v, W_o) and the head split/merge\n",
    "        helpers are inherited from MultiHeadSelfAttention.\n",
    "        \"\"\"\n",
    "        # Split into heads:\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        q_heads = self.transpose_qkv(self.W_q(tgt))\n",
    "        k_heads = self.transpose_qkv(self.W_k(src))\n",
    "        v_heads = self.transpose_qkv(self.W_v(src))\n",
    "\n",
    "        # Unlike self-attention, the mask is an outer product of the two\n",
    "        # masks: entry (t, s) is nonzero only if both tgt_mask[t] and\n",
    "        # src_mask[s] are nonzero.\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        cross_mask = tgt_mask[:, :, None] * src_mask[:, None, :]\n",
    "        # One copy per head:\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        cross_mask = cross_mask.repeat_interleave(self.num_heads, dim=0)\n",
    "\n",
    "        # (batch_size * num_heads) * tgt_seq_len * (hidden_size / num_heads)\n",
    "        attended = self.attention(q_heads, k_heads, v_heads, cross_mask)\n",
    "        # Merge heads back: batch_size * tgt_seq_len * hidden_size\n",
    "        return self.W_o(self.transpose_output(attended))\n",
    "\n",
    "# TransformerDecoderLayer = TransformerLayer plus a multi-head cross-attention\n",
    "# sub-layer that lets the target attend to the encoder output\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    \"\"\"Decoder layer: self-attention -> cross-attention -> feed-forward.\n",
    "\n",
    "    Each sub-layer's output is combined with its input by an AddNorm layer.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        \"\"\"src_states/src_mask: encoder output and its mask;\n",
    "        tgt_states/tgt_mask: embedded target prefix and its mask.\n",
    "        Returns the updated target states.\n",
    "        \"\"\"\n",
    "        # Multi-head self-attention over the target prefix.\n",
    "        # NOTE(review): no causal mask is built here; tgt_mask is passed\n",
    "        # through unchanged (TransformerDecoder passes all ones). Whether\n",
    "        # MultiHeadSelfAttention adds one itself is defined in transformer.py\n",
    "        # -- confirm. TransformerDecoder feeds only the already generated\n",
    "        # prefix and reads only the last position, so with a single layer\n",
    "        # no future token is visible.\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # Multi-head cross-attention: target queries, encoder keys/values\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # Position-wise feed-forward network\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    \"\"\"Single-layer Transformer decoder that emits one token per step.\n",
    "\n",
    "    Handles one sentence at a time. forward() reads the notebook global\n",
    "    SOS_token, which must exist before the decoder is called.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\n",
    "            hidden_size)\n",
    "        # One layer for simplicity\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\n",
    "            dropout, intermediate_size)\n",
    "        # Like TransformerLM, the decoder needs an output projection\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        \"\"\"Decode a whole target sequence.\n",
    "\n",
    "        Args:\n",
    "            src_states: 1 * src_seq_len * hidden_size encoder output.\n",
    "            src_mask: encoder mask, passed on to cross-attention.\n",
    "            tgt_tensor: optional 1 * tgt_seq_len gold ids. If given, decode\n",
    "                tgt_seq_len steps with teacher forcing; otherwise decode\n",
    "                embedding_layer.max_len steps greedily.\n",
    "\n",
    "        Returns:\n",
    "            (log_probs, None, None); the two Nones keep the interface\n",
    "            aligned with RNNDecoder.\n",
    "        \"\"\"\n",
    "        # Only one sentence at a time: 1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "\n",
    "        if tgt_tensor is not None:\n",
    "            # Only one sentence at a time: 1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "\n",
    "        # Start from <sos>. Create it on the encoder output's device so the\n",
    "        # decoder also works when the model and data live on a GPU.\n",
    "        decoder_input = torch.full((1, 1), SOS_token, dtype=torch.long,\n",
    "            device=src_states.device)\n",
    "        decoder_outputs = []\n",
    "\n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if tgt_tensor is not None:\n",
    "                # Teacher forcing: extend the prefix with the gold token\n",
    "                next_token = tgt_tensor[:, i:i+1]\n",
    "            else:\n",
    "                # Greedy: extend the prefix with the highest-scoring token;\n",
    "                # detach so no gradient flows back through this choice\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                next_token = topi.squeeze(-1).detach()\n",
    "            decoder_input = torch.cat((decoder_input, next_token), 1)\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        return decoder_outputs, None, None\n",
    "\n",
    "    # Unlike RNNDecoder.forward_step (one token + hidden state in, one\n",
    "    # distribution + hidden state out), this takes the whole target prefix\n",
    "    # and returns only the scores for the next token; there is no hidden state.\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        \"\"\"tgt_inputs: 1 * prefix_len token ids.\n",
    "\n",
    "        Returns the output-layer scores at the last prefix position\n",
    "        (presumably 1 * 1 * vocab_size; depends on the layer's output shape).\n",
    "        \"\"\"\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # Build position ids and mask on the input's device\n",
    "        device = tgt_inputs.device\n",
    "        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32,\n",
    "            device=device)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\n",
    "            tgt_mask)\n",
    "        # Only the last position predicts the next token\n",
    "        return self.output_layer(hidden_states[:, -1:, :])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    \"\"\"Vocabulary of one language: word <-> index maps plus word counts.\n",
    "\n",
    "    Indices 0 and 1 are reserved for <sos> and <eos>; every other word gets\n",
    "    the next free index in order of first appearance.\n",
    "    \"\"\"\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {SOS_token: \"<sos>\", EOS_token: \"<eos>\"}\n",
    "        # Counts the two reserved tokens\n",
    "        self.n_words = len(self.index2word)\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        \"\"\"Add every space-separated token of `sentence`.\"\"\"\n",
    "        for token in sentence.split(' '):\n",
    "            self.addWord(token)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        \"\"\"Count `word`, assigning it a new index the first time it is seen.\"\"\"\n",
    "        if word in self.word2index:\n",
    "            self.word2count[word] += 1\n",
    "            return\n",
    "        new_index = self.n_words\n",
    "        self.word2index[word] = new_index\n",
    "        self.word2count[word] = 1\n",
    "        self.index2word[new_index] = word\n",
    "        self.n_words = new_index + 1\n",
    "\n",
    "    def sent2ids(self, sent):\n",
    "        \"\"\"Map a space-separated sentence to ids (KeyError on unseen words).\"\"\"\n",
    "        return list(map(self.word2index.__getitem__, sent.split(' ')))\n",
    "\n",
    "    def ids2sent(self, ids):\n",
    "        \"\"\"Map ids back to a space-separated sentence.\"\"\"\n",
    "        return ' '.join(self.index2word[i] for i in ids)\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# Text normalization helpers applied to every line of the corpus\n",
    "def unicodeToAscii(s):\n",
    "    \"\"\"Strip combining marks (e.g. accents) after NFD decomposition.\n",
    "\n",
    "    'café' -> 'cafe'. Despite the name, the result is not guaranteed to be\n",
    "    ASCII: characters without combining marks (e.g. Chinese) are kept.\n",
    "    \"\"\"\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    \"\"\"Lower-case, strip accents, detach , . ! ? from the preceding word and\n",
    "    collapse whitespace runs into a single space.\n",
    "\n",
    "    Collapsing matters because Lang splits sentences on ' ': without it an\n",
    "    input such as 'done .' would become 'done  .' and yield an empty token.\n",
    "    \"\"\"\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # Insert a space before punctuation\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    # Collapse any whitespace run into one space\n",
    "    s = re.sub(r\"\\s+\", \" \", s)\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read a parallel corpus: line i of {lang1}.txt and line i of {lang2}.txt\n",
    "# form one source/target sentence pair.\n",
    "def readLangs(lang1, lang2):\n",
    "    \"\"\"Load and normalize the files `{lang1}.txt` and `{lang2}.txt`.\n",
    "\n",
    "    Both files are read as UTF-8 and split into lines. Every line goes\n",
    "    through normalizeString; a 'zh' side is additionally turned into\n",
    "    space-separated characters. Lines are paired by position (zip stops at\n",
    "    the shorter file, so compare the two printed counts).\n",
    "\n",
    "    Returns (input_lang, output_lang, pairs): two new Lang objects (their\n",
    "    vocabularies are filled later) and a list of [src, tgt] strings.\n",
    "    \"\"\"\n",
    "    def read_lines(lang):\n",
    "        # `with` closes the file handle once it has been read\n",
    "        with open(f'{lang}.txt', encoding='utf-8') as f:\n",
    "            return f.read().strip().split('\\n')\n",
    "\n",
    "    def normalize_lines(lines, lang):\n",
    "        lines = [normalizeString(s) for s in lines]\n",
    "        if lang == 'zh':\n",
    "            # Character-level tokens: drop spaces, then space-separate chars\n",
    "            lines = [' '.join(s.replace(' ', '')) for s in lines]\n",
    "        return lines\n",
    "\n",
    "    lines1 = read_lines(lang1)\n",
    "    lines2 = read_lines(lang2)\n",
    "    print(len(lines1), len(lines2))\n",
    "\n",
    "    lines1 = normalize_lines(lines1, lang1)\n",
    "    lines2 = normalize_lines(lines2, lang2)\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['金 融 基 础 知 识 （ 附 微 课 第 4 版 ）', 'basics of finance (with micro-study version 4)']\n"
     ]
    }
   ],
   "source": [
    "# Keep only short sentence pairs so that training stays fast\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    \"\"\"True if both sides of pair `p` have fewer than MAX_LENGTH tokens.\"\"\"\n",
    "    src_len = len(p[0].split(' '))\n",
    "    if src_len >= MAX_LENGTH:\n",
    "        return False\n",
    "    return len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    \"\"\"Return the pairs accepted by filterPair, in their original order.\"\"\"\n",
    "    return list(filter(filterPair, pairs))\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    \"\"\"Read, filter and index a parallel corpus.\n",
    "\n",
    "    Prints progress, fills both Lang vocabularies from the kept pairs and\n",
    "    returns (input_lang, output_lang, pairs).\n",
    "    \"\"\"\n",
    "    src_lang, tgt_lang, raw_pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(raw_pairs)} 对序列\")\n",
    "    kept_pairs = filterPairs(raw_pairs)\n",
    "    print(f\"过滤后剩余 {len(kept_pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    for src_sent, tgt_sent in kept_pairs:\n",
    "        src_lang.addSentence(src_sent)\n",
    "        tgt_lang.addSentence(tgt_sent)\n",
    "    for lang in (src_lang, tgt_lang):\n",
    "        print(lang.name, lang.n_words)\n",
    "    return src_lang, tgt_lang, kept_pairs\n",
    "\n",
    "# Build the vocabularies and print one random pair as a sanity check\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引），并在两个句子的末尾都添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    \"\"\"Build the training set as token-id sequences.\n",
    "\n",
    "    Returns (input_lang, output_lang, train_data), where each item of\n",
    "    train_data is [src_ids, tgt_ids] with EOS_token appended to both.\n",
    "    \"\"\"\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    for src_sent, tgt_sent in pairs:\n",
    "        # Mark the end of both sentences with <eos>\n",
    "        src_ids = input_lang.sent2ids(src_sent) + [EOS_token]\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent) + [EOS_token]\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "\n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b85d1ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, step=2002, loss=0.0094: 100%|█████████| 20/20 [25:00<00:00, 75.05s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABJ2UlEQVR4nO3dd1gU1/4G8HdpS186iCIoVkDRiBosUWNvMT0xTdPuNdca0zRV09B000xi8jMmuYneG0u8sZsoxoi9YQlWBBUEEVnqArvn9wdhZGEXdpdlZxfez/PwPDszZ2a/4xh5c+bMGYUQQoCIiIjIDjnJXQARERGRMQwqREREZLcYVIiIiMhuMagQERGR3WJQISIiIrvFoEJERER2i0GFiIiI7JaL3AU0hk6nw+XLl+Hj4wOFQiF3OURERGQCIQQKCwsRHh4OJ6f6+0wcOqhcvnwZERERcpdBREREFsjMzESbNm3qbePQQcXHxwdA1Yn6+vrKXA0RERGZQq1WIyIiQvo9Xh+HDirVt3t8fX0ZVIiIiByMKcM2OJiWiIiI7BaDChEREdktBhUiIiKyWwwqREREZLcYVIiIiMhuMagQERGR3WJQISIiIrvFoEJERER2i0GFiIiI7BaDChEREdktBhUiIiKyWwwqREREZLcYVIwoLddCCCF3GURERC0ag4oBF/KK0fXVjZi14rDcpRAREbVoDCoGfLsrHQDwy+HL8hZCRETUwjGoGODmwj8WIiIieyDrb+TKykq8/PLLaNeuHTw8PNC+fXu8/vrr0Ol0cpYFZ4VC1u8nIiKiKi5yfvnChQvxxRdfYNmyZYiNjcX+/fvx6KOPQqVSYebMmbLV5ezEoEJERGQPZA0qKSkpmDBhAsaOHQsAiIqKwk8//YT9+/cbbK/RaKDRaKRltVrdJHU5sUeFiIjILsh662fAgAH47bffcOrUKQDAkSNHsHPnTowZM8Zg+6SkJKhUKuknIiKiSeqKDPRskuMSERGReWTtUXnhhRdQUFCALl26wNnZGVqtFm+99RYmTpxosP3cuXMxe/ZsaVmtVjdJWEmMDgQAuDqzZ4WIiEhOsgaVFStW4IcffsCPP/6I2NhYHD58GLNmzUJ4eDgmTZpUp71SqYRSqWzyulydqzqaKrQCQggoeCuIiIhIFrIGleeeew5z5szB/fffDwDo1q0bLly4gKSkJINBxVaqgwpQFVbcXBhUiIiI5CDrGJWSkhI4OemX4OzsLPvjyW56QUXeWoiIiFoyWXtUxo8fj7feegtt27ZFbGwsDh06hA8++ACPPfaYnGXpjU1hUCEiIpKPrEHlk08+wSuvvIJ//etfyMnJQXh4OP75z3/i1VdflbMsODspoFAAQgDlDCpERESykTWo+Pj44KOPPsJHH30kZxl1KBQKuDo7obxShwot36BMREQkF77UxojqcSoVlexRISIikguDihHV41Q4RoWIiEg+DCpGuPzdo8IxKkRERPJhUDGi+g3KgkNUiIiIZMOgYkT1C5S1OiYVIiIiuTCoGOH0d1LRsUuFiIhINgwqRjgpqoOKzIUQERG1YAwqRlTf+hHsUSEiIpINg4oR7FEhIiKSH4OKEX/nFI5RISIikhGDihE3elQYVIiIiOTCoGKEE+dRISIikh2DihHVjydzHhUiIiL5MKgY4cQxKkRERLJjUDGCt36IiIjkx6BiBHtUiIiI5MegYoSC86gQERHJjkHFCPaoEBERyY9BxYgbY1QYVIiIiOTCoGLEjbcny1wIERFRC8agYkT1rR/Oo0JERCQfBhUjOIU+ERGR/BhUjOA8KkRERPJjUDGCb08mIiKSH4OKEU6cR4WIiEh2DCpGcB4VIiIi+TGoGOHsxHlUiIiI5MagYgSn0CciIpIfg4oRnEeFiIhIfgwqRnAKfSIiIvnJGlSioqKgUCjq/EydOlXOsgDwqR8iIiJ74CLnl+/btw9arVZaPnbsGIYPH4577rlH
xqqqcB4VIiIi+ckaVIKDg/WWFyxYgOjoaAwaNEimim5gjwoREZH8ZA0qNZWXl+OHH37A7NmzpSduatNoNNBoNNKyWq1usnr4eDIREZH87GYw7Zo1a3D9+nVMnjzZaJukpCSoVCrpJyIiosnq4a0fIiIi+dlNUPnmm28wevRohIeHG20zd+5cFBQUSD+ZmZlNVg9v/RAREcnPLm79XLhwAVu3bsWqVavqbadUKqFUKm1SE+dRISIikp9d9KgsXboUISEhGDt2rNylSDiPChERkfxkDyo6nQ5Lly7FpEmT4OJiFx08ADiFPhERkT2QPahs3boVGRkZeOyxx+QuRQ/fnkxERCQ/2bswRowYYZe3V6ofT2aPChERkXxk71GxVxyjQkREJD8GFSM4jwoREZH8GFSM4DwqRERE8mNQMUIaTMukQkREJBsGFSNu9KgwqBAREcmFQcUIzqNCREQkPwYVI5z//pNhjwoREZF8GFSMKCitAAAs/TNd3kKIiIhaMAYVI37YnSF3CURERC0egwoRERHZLQYVE3B2WiIiInkwqJiAOYWIiEgeDCom4JM/RERE8mBQMSLM1136XMnJVIiIiGTBoGLE15MSpM/b03JkrISIiKjlYlAxIqaVr/R5yg8HUaypRGFZhYwVERERtTwuchdgr5yq30r4t9jXNgEA0t4cBaWLsxwlERERtTjsUTFTbqFG7hKIiIhaDAYVM1W/rJCIiIiaHoMKERER2S0GFSIiIrJbDCpERERktxhUiIiIyG4xqBAREZHdYlAxE9+kTEREZDsMKmZKvVggdwlEREQtBoNKPQK83Oqse+rfB2WohIiIqGViUKnHT0/eLHcJRERELRqDSj0iAz3lLoGIiKhFY1Cph7MTp8snIiKSk+xB5dKlS3jooYcQGBgIT09P9OjRAwcOHJC7LACAC4MKERGRrFzk/PL8/Hz0798fQ4YMwYYNGxASEoKzZ8/Cz89PzrIkfAEhERGRvGQNKgsXLkRERASWLl0qrYuKipKvIAsJIaAurYTK01XuUoiIiJoVWW/9rF27FgkJCbjnnnsQEhKCnj17YsmSJUbbazQaqNVqvR978PzPRxH/+mbsPpcndylERETNiqxB5dy5c1i8eDE6duyITZs2YcqUKZgxYwa+++47g+2TkpKgUqmkn4iICBtXrG/HqVxkXivBfw9cBADc/9VuWeshIiJqbhRCxjnh3dzckJCQgF27dknrZsyYgX379iElJaVOe41GA41GIy2r1WpERESgoKAAvr6+TVJj1Jx1ddbteXEozuYW4YEle+psS18wtknqICIiai7UajVUKpVJv79l7VFp1aoVYmJi9NZ17doVGRkZBtsrlUr4+vrq/cjhjs/+RMpZ3uYhIiJqarIGlf79+yMtLU1v3alTpxAZGSlTRaa5XFCGT34/I3cZREREzZ6sQeXpp5/G7t278fbbb+PMmTP48ccf8dVXX2Hq1KlylkVERER2Qtag0rt3b6xevRo//fQT4uLi8MYbb+Cjjz7Cgw8+KGdZREREZCdknUcFAMaNG4dx48bJXQYRERHZIdmn0CciIiIyhkGlAY8kmjewN/1qcRNVQkRE1PIwqDTg9QlxZrUf/N52o9vKK3XQVGobWREREVHLwaBiIzqdQN+3t+Km17egUquTuxwiIiKHwKBigoEdgxp9jJIKLfJLKlBcrsWVQk3DOxARERGDiil0Zr5l4PL10jrrFDU+y/jWAiIiIofCoGICZyfz/pj6Lfi9zjpFjaTCnEJERGQaBhUTzBsfg3CVu1n7LNz4F58AIiIiaiQGFRO0D/bGrrlDzdpn8fazek8AKfRu/hAREZEpGFRsRODG/R7e+iEiIjINg4qN1AwnNUMLERERGcegYiMFpRXS59NXimSshIiIyHEwqDSx71LScfTidbi7Okvrvtl5XsaKiIiIHIfsb09u7l795TgAYN9Lw6R1Kefy5CqHiIjIobBHxUa0Oo5LISIiMheDio1o
+agPERGR2RhUzODmbPkfV0ZeiRUrISIiahkYVMzw55xbLd534pLdVqyEiIioZWBQMUOwjxL+nq5yl0FERNRiMKiYSaHgVPhERES2wqBCREREdotBhYiIiOwWgwoRERHZLQYVM3GEChERke0wqBAREZHdYlAhIiIiu8WgYqZ/DmovdwlEREQtBoOKmZ4c2B6/Th8gdxlEREQtAoOKmRQKBeJaq9Ap1BsA0CvS36Lj6Pg2ZSIioga5yF2Ao/rxyZux9cQVjI8Ph5fSBdvTcjB56T6T9//54EXcmxDRhBUSERE5PvaoWCjIW4n7+7SFl7Iq6w3uHGLW/iln85qiLCIiomZF1qAyb948KBQKvZ+wsDA5S7IZzsdCRETUMNlv/cTGxmLr1q3SsrOzs4zV2E5ukUbuEoiIiOye7EHFxcXF5F4UjUYDjebGL3i1Wt1UZTU5TYVO7hKIiIjsnuxjVE6fPo3w8HC0a9cO999/P86dO2e0bVJSElQqlfQTEeG4g1EVvPdDRETUIFmDSt++ffHdd99h06ZNWLJkCbKzs9GvXz/k5RkeaDp37lwUFBRIP5mZmTau2Hou5JXIXQIREZHdk/XWz+jRo6XP3bp1Q2JiIqKjo7Fs2TLMnj27TnulUgmlUmnLEq1C5eGKgtIKvXXZ6jKZqiEiInIcst/6qcnLywvdunXD6dOn5S7FIsZmrA30drNxJURERM2DXQUVjUaDkydPolWrVnKXYpG41irMGtZRb12IjxJfPtQLm5++RaaqiIiIHJesQeXZZ59FcnIyzp8/jz179uDuu++GWq3GpEmT5CyrUWYN66S3vPKpfugY6oNOoT5w4gBaIiIis8gaVC5evIiJEyeic+fOuPPOO+Hm5obdu3cjMjJSzrKsqubTPVtmD5I+B/s43lgbIiIiW5N1MO3y5cvl/Pom17WVL1r7eUjL0cHe0ufwGuuJiIjIMNknfGuO9r40FIVllXrBpLbb4sNtWBEREZFjsqvBtM1FiI+70ZAytlvVQGEXDlghIiJqEIOKjVWPWXlt7XFc5ft+iIiI6sWgYmOKGqNrE97cWk9LIiIiYlCxMd7wISIiMh2Dio1xaAoREZHpGFRsTMHXJhMREZmMQcXGGFOIiIhMx6BiY7V7VGYtP2TW/kcyr+Pt9SdRpKmU1s1dlYqk9SetUh8REZE94YRvNlb7zs+aw5dx9GIBxsWHY/bwToZ3qmHCZ38CAMordZh3Wywy8krw094MAMDzo7rAmYNgiIioGWGPio0ZihHnrhbj499OY1/6NRTX6CmpSasT+GrHWWl58/FsAICmUtsUZRIREdkFBhUbc6pnMO09X6Qg9rVNOH+1uM62NYcu4e31f0nLlwvKAACiRhudECAiImpOGFRszJSHfn7cc6HOurO5RXXWpWUX4oq6TFpOvVQArY5hhYiImg+OUbExU4KKU61xJtvScnAut24vy8iPdugt3/n5Ljx8cyTeuD2uUTUSERHZCwYVGzNlHhVnhQIbj2WhpFyLEB93PLp0n8nH/373BQYVIiJqNhhUbMyUZ3KKNZWY8sPBJq+FiIjI3nGMio2ZcutnWUrdMSpEREQtEYOKjSk4Ny0REZHJGFSIiIjIbjGo2JgAHx8mIiIyFYMKERER2S0GFRvj5LFERESmsyioLFu2DOvWrZOWn3/+efj5+aFfv364cIFPrNSHOYWIiMh0FgWVt99+Gx4eHgCAlJQUfPrpp3jnnXcQFBSEp59+2qoFEhERUctl0YRvmZmZ6NChAwBgzZo1uPvuu/GPf/wD/fv3x+DBg61ZX7Oz68xVuUsgIiJyGBb1qHh7eyMvLw8AsHnzZgwbNgwA4O7ujtLSUutV1wyl55XIXQIREZHDsKhHZfjw4XjiiSfQs2dPnDp1CmPHjgUAHD9+HFFRUdasj4iIiFowi3pUPvvsMyQmJiI3NxcrV65EYGAgAODAgQOYOHGiVQskIiKilsuiHhU/Pz98+umnddbPnz+/0QUR
ERERVbOoR2Xjxo3YuXOntPzZZ5+hR48eeOCBB5Cfn2+14pqjvu0C5C6BiIjIYVgUVJ577jmo1WoAQGpqKp555hmMGTMG586dw+zZsy0qJCkpCQqFArNmzbJof0fRMdRb7hKIiIgchkW3fs6fP4+YmBgAwMqVKzFu3Di8/fbbOHjwIMaMGWP28fbt24evvvoK3bt3t6Qch8KZaYmIiExnUY+Km5sbSkqqHrPdunUrRowYAQAICAiQelpMVVRUhAcffBBLliyBv7+/JeU4FB2DChERkcksCioDBgzA7Nmz8cYbb2Dv3r3S48mnTp1CmzZtzDrW1KlTMXbsWGkulvpoNBqo1Wq9H0fTNsBT7hKIiIgchkVB5dNPP4WLiwt+/vlnLF68GK1btwYAbNiwAaNGjTL5OMuXL8fBgweRlJRkUvukpCSoVCrpJyIiwpLyZfVo/yg81r+d3GUQERE5BIUQ8oyayMzMREJCAjZv3oz4+HgAwODBg9GjRw989NFHBvfRaDTQaDTSslqtRkREBAoKCuDr62uLsq0mas66hhtZKH3B2CY7NhERUWOp1WqoVCqTfn9bNJgWALRaLdasWYOTJ09CoVCga9eumDBhApydnU3a/8CBA8jJyUGvXr30jrljxw58+umn0Gg0dY6lVCqhVCotLZmIiIgcjEVB5cyZMxgzZgwuXbqEzp07QwiBU6dOISIiAuvWrUN0dHSDxxg6dChSU1P11j366KPo0qULXnjhBZMDDxERETVfFgWVGTNmIDo6Grt370ZAQNUEZnl5eXjooYcwY8YMrFvX8G0NHx8fxMXF6a3z8vJCYGBgnfVERETUMlkUVJKTk/VCCgAEBgZiwYIF6N+/v9WKI9sqr9Qh5Vweekf5w9PN4ruCREREVmPRbyOlUonCwsI664uKiuDm5mZxMdu3b7d4X0fzxUO9MOWHA3KXoeft9Sfx7a50DO0Sgm8m95a7HCIiIsseTx43bhz+8Y9/YM+ePRBCQAiB3bt3Y8qUKbjtttusXWOzNCouTO4S6liWkg4A+O2vHHkLISIi+ptFQeXjjz9GdHQ0EhMT4e7uDnd3d/Tr1w8dOnQw+mgxmSbtTdPnobE2Tu9PRET2xqJbP35+fvjll19w5swZnDx5EkIIxMTEoEOHDtaur8VRuvBpJyIiomomB5WG3opcc3zJBx98YHFBRERERNVMDiqHDh0yqZ1CobC4GCIiIqKaTA4q27Zta8o6yIgRMaHYfOJKnfW39wjHmsOXZaiIiIjIdiwaTEu289UjCQbXf3hfD2yYOdBq3yPTK5+IiIjqxaAio3fv7m7xvgqFAioPV6vV8t7mNKsdi4iIyFoYVGR0T0IEVj6VaHT7yNjQevd3suJ4oM+2ndVbLimvtNqxiYiILMWgIrNekQF11rUL8gIA3N6jdb37OhnIKaG+1nm79NiPd1rlOERERI3BF7rYof9NH4BzuUXo1lpVbztNpa7OOnOGmmgqtXjo6z1IiKobls5fLTb9QERERE2EQcUOeStd0L2Nn0X76swIKhuPZWNfej72pedb9F1ERERNjbd+7MDqf/XDyNhQJD832OD2oV1CTD6WKU/vZF4rwd7z11BWoTX5uERERHJgj4od6NnWH18+bPgxZAD49IGbsP/CNaRlF+LNdSel9TXH0g7rGoKtJ3PQUEzJLijDwHeq5sR5+ObIxpRNRETU5BhUHICHmzMGdgxGv+gghPi6IyHSHwDg6Xbj8ildq94RpGugR+XmpN+kz5uOZ5v0/e9vTkOEvyfu7R1hbulERESNwls/DsTZSYHb4sMR7ucBAAjwcpO2xbTyBQDozBikklOoabDNsUsF+OT3M3h+5VEzqyUiImo8BpVmwtOtqkelZofK0YvX8fzPR5BTWGbxcV9andrY0oiIiCzGWz/NhPPfk6rUvPVz26d/AgByCzVY+mgfs3pbqh25WCB9rtDq4OrMbEtERLbD3zrNhNvfAaK4XIuzuUV6245fVuP8
1eIGx680ZMMx08a0EBERWQuDioOb3C8K/TsEol90kLRu6PvJUJdVSMs5hRoMeW87fvsrR29fF0NT29ajqIzT6hMRkW3x1o+Dm3dbLADg8vVSvfXrj2bVafvT3gy95UozbwWZmWuIiIgajT0qzUTtFxTOWVV3EOz2tFyzjnkmp6jhRkRERE2IQaWZaIrejmEfJOstK13514WIiGyLv3maCYWi6e/L+Hu6NdyIiIjIihhUmglbjB+pOcEcERGRLTCoNBO26FFp5NPNREREZmNQaSZs0aPCnEJERLbGoNJM2KZHhVGFiIhsi0GlmWCPChERNUcMKs1E7XlUmkJ+cXmTfwcREVFNDCrNhC2CStKGv5r8O4iIiGqSNagsXrwY3bt3h6+vL3x9fZGYmIgNGzbIWZLDskFOwYW84qb/EiIiohpkDSpt2rTBggULsH//fuzfvx+33norJkyYgOPHj8tZlkOyRVCxxYBdIiKimmR9KeH48eP1lt966y0sXrwYu3fvRmxsbJ32Go0GGo1GWlar1U1eo6Owxa0fxhQiIrI1uxmjotVqsXz5chQXFyMxMdFgm6SkJKhUKuknIiLCxlXaL1fnpr+U7FAhIiJbkz2opKamwtvbG0qlElOmTMHq1asRExNjsO3cuXNRUFAg/WRmZtq42pbt/t5t5S6BiIhaGFlv/QBA586dcfjwYVy/fh0rV67EpEmTkJycbDCsKJVKKJVKGaokAGjj7yF3CURE1MLIHlTc3NzQoUMHAEBCQgL27duHRYsW4csvv5S5MqpNx5lpiYjIxmS/9VObEEJvwCzZDx1zChER2ZisPSovvvgiRo8ejYiICBQWFmL58uXYvn07Nm7cKGdZZAR7VIiIyNZkDSpXrlzBww8/jKysLKhUKnTv3h0bN27E8OHD5SyLjNCxS4WIiGxM1qDyzTffyPn1ZCbmFCIisjW7G6NClvvg3vhGH6NzqI/Rbbz1Q0REtsag0ozceVObRh/jf9MHGN3GHhUiIrI1BpUW7OlhnaTPd/dqgz/n3Ao3F+N/JThGhYiIbI1BpZn54qGbTG47c1hH6XN8hB9a+1VN6DZjaEeD7Rt76+dklhq/nbzSqGMQEVHLwqDSzIyKa4UTr49ssF37YC+9ZRenGy/ymWU0qDSuttGL/sDjy/bjjV9PNO5ARETUYjCoNEOebnUf5urfIVD63K21Cl881Etve6dQb+mzk5MCQd5udY4h6ulR2XQ8G8cuFZhU3zc7z2P3uTyT2hIRUcvGoNJC+HncCB7/mz4Anf5+uueXqf2x6P4e6BUZoNfe0KBaY7d+TlxW45/fH8C4T3aaXM9b605CXVZhcnsiImqZGFRagHfu7g6nGrd2aoqP8MOEHq3rrG+l8sDjA9rprTN060dTqcWJLHWDNZSWa/WWUy8VoPu8zQ3uR0RELZvsLyWkpuGjdEGhphIAcG9CBPpFB+L4pQJM7h9l8jHmju6C0XFh2HziCr7acQ7aWknlnY1/4fPtZ/XWnc0tQkZeCdRlFegc5oMuYb4AgOEfJhv8Dp1OGA1RREREDCrN1L29I/DNzvPScht/T/z+7GCzjuHi7ISEqAAkn8oFUHeMSu2QAgBD39cPJOkLxiK7oAwX80sNfkelTsCNQYWIiIzgrZ9mysvN2WrHUiiqgoSlT/3cnPSb0W2pl65bdlAiImoRGFSaqdt7Vo07iY/wa/SxnKWgYn5SeW9TWr3b71qcgrTsQovqIiKi5o9BpZlqH+yNw68Ox6qn+jX6WNV3ZiwJKp9uO9Ngm5Ef7TD7uERE1DJwjEoz5udZdy4US1QPdtXprHI4IiIik7FHhRqkaESPChERUWMwqFCDnBo5mJaIiMhSDCrUoOoxKvVNoU9ERNQUGFSoQdU9KtnqMuxLvyZzNURE1JIwqFCDqoPKrrN5uOeLFPxy+JLMFRERUUvBoEINqj1x7Mzlh6HjgBUiIrIBBhVqUO138Yzr3gq93twiUzVERNSScB4VatAV
dZne8q9Hs2SqhIiIWhr2qFCDPttW9+WDREREtsCgQkRERHaLQYWIiIjsFoMKyW5gxyC5SyAiIjvFoEKyc3PmX0MiIjKMvyFIdlpOzU9EREYwqJBNrHyqn9FtWk4eR0RERjCoUKNFB3s12Ma59vS2NRzJvG7FaoiIqDmRNagkJSWhd+/e8PHxQUhICG6//XakpaXJWRIZMCImtN7tvz0zuMFjhPoqjW5Tl1WaW5KeAxeu4fuUdL7dmYioGZI1qCQnJ2Pq1KnYvXs3tmzZgsrKSowYMQLFxcVylkW1zB3T1ex93Fz0/2r5uruijb+HtUqS6HQCdy1OwSu/HMfvf+VY/fhERCQvWYPKxo0bMXnyZMTGxiI+Ph5Lly5FRkYGDhw4IGdZVEu7IOO3dgK93AAAseG+eusPvjJcb1mhAML9rB9UrpWUS5/XcWp/IqJmx67GqBQUFAAAAgICDG7XaDRQq9V6PySv3/++7ZN0Zze99V5uznrLCigavIVkiZo9NzvPXLX68YmISF5281JCIQRmz56NAQMGIC4uzmCbpKQkzJ8/38aVkTFPDY6GytMVAOCk0B8sq1AokNg+ECnn8qR1k/tFISLAEx1DvFGs0eJgRj5eW3u8UTUI3Y3POYUavLXuBFLO5eHnKf3g7upsfEciInIIdtOjMm3aNBw9ehQ//fST0TZz585FQUGB9JOZmWnDCqk2lxpP8nQI8ZY+e/7dmzJndBf99s5OGBkbhvbB3ujWRgX/v28bAbB4IKyuxn4hPkos+eM8jl1S8w3PRETNhF30qEyfPh1r167Fjh070KZNG6PtlEollErjT4+QbdV85Lhm70X1oNn4CD/c3asNPFyd4eFWt3ejvPJGd4imUmdRD0jNyeJyCjXS57wijaHmRETkYGTtURFCYNq0aVi1ahV+//13tGvXTs5yyEyX8ksNrg/0uhEm37snHm/cbvhWXiuVu/Q55WyewTYAkFuowc7TVw32uuiM9MSkXirAr0cvo0KrM7idiIgcg6xBZerUqfjhhx/w448/wsfHB9nZ2cjOzkZpqeFfgGRf/nvgosH1AqbdxukV6S991lRq62wv0lRi8tK96P3WVjz0zR5sOJYNACgsq0DyqVxUanXQGckhvx7NwrQfD+HrP86bVAsREdknWYPK4sWLUVBQgMGDB6NVq1bSz4oVK+Qsi2zEtcbLCMu1dcPNN3+cx/a0XGl564krAIBH/m8vJv3fXny+/azRHpVqCzf+ZaVqiYhIDrLf+jH0M3nyZDnLokYydVxszVn1g2oMrK2mLqvQW1516BIA4FDGdQDAB1tO8T1BRETNnN089UMtj6LGI81+nnWDSn3vB6o28J1tVq2JiIjsC4MKWay9kZcRmtPHUT2g1tCg19pzswBAsaZx7wUiIiLHwqBCFvv4/p6GN5iRVLIKygAA564WSetKy7V4esVhbD6RrdfW1VmBx5ftM7tOIiJyXHYxjwo5lkOvDEeRphIRAZ5WO+bTK45gdFwruLs64/PtZ7D67/EoNVVoBXafu2bWccN83RtuREREdos9KmQ2fy+3ekOKqY8n1/b0isMQQuCT389YWlodPdv6We1YRERkewwqZDc2HMvG2dxiqx7TyYQBuUREZL8YVMhqnhneCS5OCrw6LtbiY1j7ceMAA08TERGR42BQIZM8NTgaQNUbkI2ZPrQjTr4xCt3aqCz+Hmt3gPh6cBgWEZEj47/iZJLnRnTG+O7h6BzmU2+7mrPNWmL4hzsatX9tfNUPEZFjY48KmcTJSYGYcF+TJmGzVLsgw/OyNIahFxkSEZHjYFAhuzGoU7DVj8kp9omIHBuDCtkNpYv1/zoypxAROTYGFbIb36VcsPoxG3q7MhER2TcGFbIbpRVaqx+Tt36IiBwbgwo1a9bqUbl8vRQ/7slAYVmFVY5HRESm4ePJJKshnYOxLS23yY5vjaDyx+lcPPzNXgDAn2ev4rMHbmr0MYmIyDTsUSFZuTRy3pXaogL1
30Gks8I8KtUhBQDWHc1q/AGJiMhkDCokq95R/hbv2ycqAAAwPj5cWqdQKPDH80MwpHPVo87aJhhMeza3yOrHJCIiwxhUSFaP9m9n8b739Y7Ar9MH4IN746V1CgUQEeCJxOhAAICuCQbTDn0/mRPJERHZCIMKycrV2QluJtz+CfVV1lnn6+GKuNYqvWn7q+fNdVJUfTJ1jMrRi9dx4EI+RnyYjJfXpDbYng8TERHZBoMKyW5Ax6B6t0cEeODFMV3rrK8ZXryVVePC+0VXHas6qGhNCBSFZRW47dM/cdfiXTh1pQg/7M7AkPe219sbw8eeiYhsg0GFZDe0a0i921UerlLwqObm7IROoTdekLh+xkC8PLYr5o7pAuDGW5h1QkAIgZNZapSWG56nJb+47iPH568WY+HGv4zWxInkiIhsg0GFZPdAn7Z11r11R5z0WQEFaseCo/NGwN3VWVpuG+iJJwa2h6dbVc9K9csTdTqB307mYPSiP3Bz0m9m1fXljnPYcy7P4La84nKzjkVERJZhUCHZKRR138hc803KAkJv8Kq7q5NeSKnvmFqdwBPf7QcAFJRW4NilArNu29z31W6D69ccumTyMYiIyHKc8I3sknON8CIEENPKV1p+ycB4lTr7V/eo1Mok4z7ZCQBIXzBWWmcgJzWoWFNp/k5ERGQ2BhWyS05ON9KDTgAdQ31wS6dgnLlSiHsSIhrc37mBp35KyivxyDd7kZZdiEILQsfn28/i+VFdzN6PiIjMw6BCdifYR4lurVXS8k1t/QAA3z3Wx+RjVPeSGLvNM3HJHhzJvG5piUREZCMco0J2J/m5wXB3dcbGWQMxa1hHzDXhVk9tN279CAztUvepIoYUIiLHwB4VsjvVT+50CfNFlzDfBlobVnPCt9/+yrFabUREZFvsUaFmqXqMy59nDD9eTEREjkHWoLJjxw6MHz8e4eHhUCgUWLNmjZzlkIzm3xYLAPjHLe2tcjxnSx7lISIiuyPrrZ/i4mLEx8fj0UcfxV133SVnKSSzSf2icE9CG+m2T2NVaHVWOQ4REclL1qAyevRojB49Ws4SyI5YK6QAwOYT2VY7FhERycehBtNqNBpoNBppWa1Wy1gN2bP1qQwqRETNgUMNpk1KSoJKpZJ+IiIanviLyBr2vTQMq/7VDx1CvOUuhYioRXGooDJ37lwUFBRIP5mZmXKXRHYqyNvNqscL9lHiprb++OnJm6V1wsI3KFdqddDqBK4WafTWCyGgM+M9RERELYFD3fpRKpVQKpVyl0EOIMTHHVeLrP+GYzfnG9m+Uifg6lz36aLSci3WHrmE23u2htJF/+WJz/98BP/ZfxHtg71wLrcYv0ztj/gIPwDA5KX7cPl6KdbPHAhXZ4f6fwgioibDfw2pWXp+VGeL923t52F0m6vLjWBi7Mmirq9uxAsrU9H55Y11tv1n/0UAwLncYgDAt7vSAVRN9Z98Khenc4pw9GKBpaUTETU7svaoFBUV4cyZM9Ly+fPncfjwYQQEBKBt27YyVkaOLjrYsrEkQd5KrJsxAIt+O43YcBW+2Xkez47oJG13cbqR7Su0Dd+mSTmbh25tVPByc8b9X+2us33HqVycvlKIR/5vr7TO2YlzwBARVZM1qOzfvx9DhgyRlmfPng0AmDRpEr799luZqqLmwMfd/L/ad/RsjdnDO8HP0w2vja+agO7uXm302tS81WPKXC3bT+Vg4pLdGNMtDHvOX6uzPa+4HMM/3KG3jjmFiOgGWYPK4MGDLR6QSFQfDzfnhhvVkL5grEntFAoFXJ0VqNAKaCp1OJdbhKhAL2nK/tq+TD4HwLzHpYs0lSa3JSJq7jhGhZolNwODUd+9u7tVjl090HXOyqO49f1kPPndfvzfzvOImrMOu85cbfTxP9h8qtHHICJqLhhUqFlSGHjXzz0J1pl3x+Xv3pM/TleFkt/+ysHrv54AADzw9Z5GH/9kFicyJCKqxqBCLcrmp2/Bssf66K2b3C/K
rGOoy5r21oyWt0OJiCQMKtSidAr1waBOwTjy6ggAwIiYUMz7+83N9qJ3VIDcJRAR2Q2HmvCNyFpUnq4mD6C1tepbSuYSQqC4XAtvZdV/1rmFGqzYl4EOIT4Y0DFIWk9E5Ej4Lxc1W/+dkoh7vkgBAIT6Nv8Zjaf8cACbjl/B6n/1Q8+2/uj79lbUnJHfXoMZEVF9GFSo2eodFYA/59yKJTvO4dH+UXKX0+Q2Hb8CALjj810yV0JEZD0MKtSstfbzsLsxKEREZDoOpiWyM/FtVHKXQERkNxhUiOzE23d0AwCE+Lob3L760EV8s/O8wW1lFdoGj29KGyIie8OgQmSmYV1DrHKcfz/RV/r845N9pfcIGXqHkE4n8PSKI3jj1xP49ehlbPsrR6/dA0vqvvCwttjXNlmhaiIi2+IYFSIzTezTFltP5pjcPsjbDVeLyvXWJbYPRL/oQPzzlvboEOKNftFB+OXwJQBApYG3Mmery6TP0348JH2ODvbCkkcScDDjeoN1aHWcSI6IHA+DCpGZDL0FuT7rZw7ExK9242xuMT6Z2BMDOwZB5eEKhUKBuWO6Su2c/56a/+jF63r7n8xSY/SiPwwe+2xuMW59P9m8E7CiD7acglanw3Mju8hWAxE1b7z1Q2SmxPaBddbte2mY0faebi747ZnBSF8wFuPjw+Hn6WbwXURbTlQ9Xlxziv607EJ8v/uCFaquUl5Z97aSOdYdzcLb609CpxMoLKvAx7+dxmfbziLzWomVKiQi0sceFSIzRQd7S5/DVe549554BPvoTyg3KTESy1KqAkb1Swwb4uJ04/8bFm78C7+fzEHalUIrVHxDSXkl3FzcLN5/6o8HAQA3tfVHzayVea0EEQGejS2PiKgO9qgQmSnI58Yv+h+fvBn9OwQBAF4e2xU92/ohdd4I9G534309ziYGlSDvG8ddvP2s1UMK0LhxKr8evSx9zi3S4J/fH5CWzRmzQ0RkDgYVIjN5uDqjZ1s/dAr11utFeGJge6z+V3/4uLsitMYjxqb2qIzp1srqtdZWbuCJIkOEEKjQ6rDq4EWM/2QnLl0v1RvE++OeDL32+9LNG7dDRGQq3vohMpNCocDKKf0gYLy3JL6NHyIDPdHG38PgeBRDPN2crVilYZoK04LKP74/gEMZ+dLTSv0X/K63/WSWWm859VKBdQokIqqFPSpEFnByUtR7S8fNxQm/zR6EHx7va7RNbVrRNI8P73lxqPS5dqA4k1OIqDnr8MSyfXrrt5y4UueRaiIiOTCoEDURF2cnk3tTAKBTiI9Zx48MrDt49cUxXXBHz9Z662rehgqpNeh32Ac7AFSNMflwyyn8la3Gudwis+ogImpKvPVDZCecTBzLUm37s4Px+LL9OHrxOqKDvfH8qM7oFVk1iHf1oUt6bbu28sXJLDVOXSnEhWsluOumNnV6hBb9dhqLfjttUe0DOwZZtF9tl66X4r/7MzGmWyuM+HAHBnUKxrLH+ljl2ETkmBhUiOzI8n/cjPu/Mjwd/uanb8HSP8/jp72Z+Oeg9lAoFPi/yb0Ntl10fw/M/s8RfDqxJwDA3bWq8/SVX44DAJ7/+ahV6/7j9NUG26RlF2LkRzswKjYMb90RB02lDuF+HnptqsfCfLS1KjAln8q1ap1E5HgUQjTRjXEbUKvVUKlUKCgogK+vr9zlEFmFVidwIa8YCoUCM5cfwqBOwXhmRGezj1Oh1cHVuSqgPLBkN3adzbN2qXrSF4zFsUsFcHNxQqdQH5RVaPH+5jT06xAEpYsTHliyx+B+faICMHdMF1wvqcCj3+6rsz19wdgmrZuIbM+c39/sUSGyM85OCrT/e1K5tdMGWHyc6pACAIU1ZrttCoFebsgpLMO4T3YCAP56YxSm/3QIW05cwZI/DL/xudre9Gu44/NdRrfrdMKk22InLquxOPksnhneCUcuXscXyefwxUM3oY2/J07nFGL90Sz8c1A0fj16GcNjwhDgZfnEd0RkOwwq
RC2Al9I6jz4P7BiEYG8lVv09BubR/lFY+mc68orLkVKjx2bo+8m4dL3UKt9ZVqmFp1vD/1SN+bjqfUj/O3JjYrpB727Xa/Px72cAAC+sTMXhV4fDz9N6YSXzWgkGvrMNXVv5YsPMgVY7LlFLx6d+iFoAU96uXNuMWzvgyKsj4KO8ERK+f7wvPrivB9IXjEX6grGYlBglbZu5/LD02VohBQAuX696c3SOugxRc9ah/dx1jX5nEQD0eH2LxftO/+kQouasQ8075wPf2Qag7hwzRNQ47FEhagEa+sW+dHJvnM0twpvrTgIAPp7YE7fFhwMAUl4ciu7zNuHFGm96rlb7HUdNYdgH+m+H1gmg08sb0DHEG6dziuDqrEDnMPMe7a5WXqmDm4v+/6+dyy3C9rRc3NolBFFBXgb3q+61aTd3PYK8lbhapNHbfuBCPnpF+ltUU21CCHyXcgEjY8MQpnJveAcznL5SiOdXHsX822LRvY2fVY9N9mvn6avwcHOSnhK0dxxMS9QCfP3HOSmEAMA7d3XHwYx8qDxc8Y9b2iPQuypwlJZr4WHGDLlCCLSbu97sesJV7rhcUGZ0+3+nJOKeL1LMPq4l/jmoPeaO7orSci3cXJwQ/eKN89kwcyB+3JMBf09XdArz0XuNQH36RQfixydvBgBcKy6HurQCecUaxLVWQelS9eerqdRKn2sSQiCroAz+nm5wcgIWbkjD//1ZNc5n9b/6oWdb8wPQCz8fxYr9mVj6aG8M6RwCoGrQds1zNXXQcnmlDot+O4UhnUOQEGX+L7rSci3yS8rrPPFlTKVWh+1puegV6Q9/A+OKyit1OJtbhC5hPmbNW9ScZF4rQXpeMQZ2DG6w7YW8YumWaOq8EfBxd23i6gzjYFoi0vP4gHZQebgi81oJnrylPXzcXXFv74g67cwJKQDM/sXwwqgueGpwtN66sgot+i34HSNiQrH1ZA4m9AhHbwt+AVrqy+Rz+DL5nMFtoxf9YdExd53Nw8trUvHEgPYY/N52vW2RgZ747IGbcM8XKbivdwReGx8DTaUO7q5Vf/btX1wPY//7eMfnu+oECp1O4GS2Gl3CfFGkqUSOugync4owPCYUl/JLkVVQhhX7MwEAjy7dhz5RAXh1fIw08Nkc29NyMHlp1ZNZn207i9NvjYaLkwK3vp+M8kodPp7YA77urgjxccfqQxcxKq4VQnyUeoOhu766EQCQ/NxgRAZW9ViVVWhRqRPwVur/StLpBF5cnYr/7L8IADj79hiUVmgxa/lhdAz1xgujuqDv21uRX1KBvu0CsOKfiXVqvl5SDi+lC05cVqN9sBd83F3x+19X8Ni3+zEpMRLzJ8TV2SeroBTrU7NxX++IOjXJbXtaDj7bdgYL7+oONxcnDFi4Tdr2zaQEDO0aWu/+Z3JuTOh4vaRCL6hUanVwcba/ESHsUSGiRlmxLwMvrEw1uj023BefPnATogI9Gww2NZ/wiZqzzqp1AlVPJ7UP9sK+9HyrH9tWOoR4S79sbosPx9oag4cb49Sbo+vcBqtmac9ZtfNJY6BQKJBdUIabk34DADw/qjMm94vCxfxSjPiwaobkfz/RF/2iA1FcrsXpK4X1Pg0GVP1ifnzZfmk5fcFYrDxwEfERKnQI8cFj3+7D7381/GbvuNa++HX6QOQWarDlxBUs3PgXCkor4OXmjOOvj9Jre+BCPvKKNOjbLhAqT1fkFJZB5eFqsHesWkFJBX45cgmBXkqM6RYGhUKBYk0lvMwIQab89xDso8S+l4ZJyztPX8XyfRl4dVwMQnzdIYTAa2uP47uUCwCAn568GTe3D4BCoZCOP3d0F0zo0Rq7zl7FhB6toRNC7wlCazHn9zeDChE12k97MzB3VSoCvNxwrbjqHUGJ7QPx5SO94Gth13KvN7Ygr9i89w31iQrA2O6t8PDNkXByUuCL5LM4n1uMFfszEdPKF+tnDpQmnrO250Z2xrqjWTjhoINpO4V649SVqgC0/+Vh
SHhzq9WOPWVQNGYN64gur2y02jENeW5kZ7y7Kc2qx0xfMBZTvj+Ajcez62z7dfoAqWeq5ndvnT0Id37+JwZ1DsGtXYLx9Ioj0j4L7+omBfsfn+iLfh0antVZXVaB7vM2m1Rv9zYqrP5Xfzg73QgfnUK9MXt4Z0z54YBJx6hteEwoljySYNG+xjhUUPn888/x7rvvIisrC7Gxsfjoo48wcKBpj/YxqBDZj5oTzFnDX9lqvPDzURy5eONFit8/3geebi7wcHVGTLgvyit1WHnwIiIDPdEv2rRp/BvbO/DvJ/riiroMs/9zBNHBXvh5Sj8AgL+XG6b++yDWpWZZfGxquZ4b2RlDOodIj9kDQJivO7LVxsdy1acx+xpi7YkXHSaorFixAg8//DA+//xz9O/fH19++SW+/vprnDhxAm3btm1wfwYVouZvX/o1pF4swCOJkVa7f7760EX8Z99FpJyrmvvll6n9cexyAV5afQz39GoDAeDnAxfh5eaMA68Mx71fpuDoxQL8/swgaTI+Q/6zP9Oqryc4+/YYOCmACq1Ap5c3WO24jdUuyAvnrxbLXQbZyNQh0XhuZBerHtNhgkrfvn1x0003YfHixdK6rl274vbbb0dSUlKd9hqNBhrNjccA1Wo1IiIiGFSIyCIVWh0qtcLsQcTG6HQCr/96An3bBaBzmA/O5RZjWEzV4MZDGfm44/NdmNinLZ4e1hFLd6UjIdIfwT5K/N/O8xjSJQQTerSWjqNQ6A9WfuY/R7Dy4EVEBXpWdcXXmPF3UKdg6ITAh/f1QJC3Ev/ecwGdQ32QEBWAgpIKuLooUFZR9Sj28UsF6NnWXxqPUlBagfj5pt1WOPDyMOkJMa1O4GSWGhuPZcPXwwVPDGiPywWl8Pd0k8ZeZOSV4JZ3t9V3SD0RAR7IvKY/B8/pt0bX6anLKihFYtLvJh+3JpWHKwpKK/DBvfGY/78TKCitsOg4crulUzB22OBdWD3b+uE//0y0+jgVhwgq5eXl8PT0xH//+1/ccccd0vqZM2fi8OHDSE5OrrPPvHnzMH/+/DrrGVSIiCwnhJBCUc3P1lBarsXxywXQCWDOyqPoFOqDN26PQ7GmEkcuXsfouFZwcVLUeU1CSXllgzMSl1VosfrQJRy7VIA7erZGXGuV9PSUpfKKNOj15laMig3D/AmxyLhWAh93F+w+m4cTWWrc36ctOof64NVfjuP+PhFIiPRHQWkF0vNK8OOeC3h1fCw0FVpkFZShXZAXvkw+i8cHtIe7mxM6v1w1Rmd0XBjKKrTYlmY4aLT280BrPw/sTb8GAOjTLgB7z1/DzheGoI2/p9HaZ684LM0aXe3Bvm3h4eqMyf2jsOtsHlQerhgRE4pL10sR5K2U/rwyr5Wglcpd6rUsKa+ETqDJnnpyiKBy+fJltG7dGn/++Sf69esnrX/77bexbNkypKXVHRDFHhUiIiLH51DzqNRO7vWleaVSCaWy6WfCJCIiIvsg28wuQUFBcHZ2Rna2/iNfOTk5CA2tf8IaIiIiahlkCypubm7o1asXtmzRfzHYli1b9G4FERERUcsl662f2bNn4+GHH0ZCQgISExPx1VdfISMjA1OmTJGzLCIiIrITsgaV++67D3l5eXj99deRlZWFuLg4rF+/HpGRkXKWRURERHZC9plpG4MTvhERETkec35/299rEomIiIj+xqBCREREdotBhYiIiOwWgwoRERHZLQYVIiIislsMKkRERGS3GFSIiIjIbjGoEBERkd2S/e3JjVE9V51arZa5EiIiIjJV9e9tU+acdeigUlhYCACIiIiQuRIiIiIyV2FhIVQqVb1tHHoKfZ1Oh8uXL8PHxwcKhcKqx1ar1YiIiEBmZmaznJ6f5+f4mvs58vwcW3M/P6D5n2NTnp8QAoWFhQgPD4eTU/2jUBy6R8XJyQlt2rRp0u/w9fVtln8Bq/H8HF9zP0een2Nr7ucHNP9zbKrza6gnpRoH0xIREZHdYlAhIiIi
u8WgYoRSqcRrr70GpVIpdylNgufn+Jr7OfL8HFtzPz+g+Z+jvZyfQw+mJSIiouaNPSpERERktxhUiIiIyG4xqBAREZHdYlAhIiIiu8WgYsDnn3+Odu3awd3dHb169cIff/whd0l1zJs3DwqFQu8nLCxM2i6EwLx58xAeHg4PDw8MHjwYx48f1zuGRqPB9OnTERQUBC8vL9x22224ePGiXpv8/Hw8/PDDUKlUUKlUePjhh3H9+vUmOacdO3Zg/PjxCA8Ph0KhwJo1a/S22/KcMjIyMH78eHh5eSEoKAgzZsxAeXl5k57f5MmT61zTm2++2WHOLykpCb1794aPjw9CQkJw++23Iy0tTa+NI19DU87Pka/h4sWL0b17d2lyr8TERGzYsEHa7sjXztRzdOTrV1tSUhIUCgVmzZolrXPYayhIz/Lly4Wrq6tYsmSJOHHihJg5c6bw8vISFy5ckLs0Pa+99pqIjY0VWVlZ0k9OTo60fcGCBcLHx0esXLlSpKamivvuu0+0atVKqNVqqc2UKVNE69atxZYtW8TBgwfFkCFDRHx8vKisrJTajBo1SsTFxYldu3aJXbt2ibi4ODFu3LgmOaf169eLl156SaxcuVIAEKtXr9bbbqtzqqysFHFxcWLIkCHi4MGDYsuWLSI8PFxMmzatSc9v0qRJYtSoUXrXNC8vT6+NPZ/fyJEjxdKlS8WxY8fE4cOHxdixY0Xbtm1FUVGR1MaRr6Ep5+fI13Dt2rVi3bp1Ii0tTaSlpYkXX3xRuLq6imPHjgkhHPvamXqOjnz9atq7d6+IiooS3bt3FzNnzpTWO+o1ZFCppU+fPmLKlCl667p06SLmzJkjU0WGvfbaayI+Pt7gNp1OJ8LCwsSCBQukdWVlZUKlUokvvvhCCCHE9evXhaurq1i+fLnU5tKlS8LJyUls3LhRCCHEiRMnBACxe/duqU1KSooAIP76668mOKsbav8it+U5rV+/Xjg5OYlLly5JbX766SehVCpFQUFBk5yfEFX/SE6YMMHoPo50fkIIkZOTIwCI5ORkIUTzu4a1z0+I5ncN/f39xddff93srp2hcxSieVy/wsJC0bFjR7FlyxYxaNAgKag48jXkrZ8aysvLceDAAYwYMUJv/YgRI7Br1y6ZqjLu9OnTCA8PR7t27XD//ffj3LlzAIDz588jOztb7zyUSiUGDRoknceBAwdQUVGh1yY8PBxxcXFSm5SUFKhUKvTt21dqc/PNN0OlUtn8z8OW55SSkoK4uDiEh4dLbUaOHAmNRoMDBw406Xlu374dISEh6NSpE5588knk5ORI2xzt/AoKCgAAAQEBAJrfNax9ftWawzXUarVYvnw5iouLkZiY2OyunaFzrObo12/q1KkYO3Yshg0bprfeka+hQ7+U0NquXr0KrVaL0NBQvfWhoaHIzs6WqSrD+vbti++++w6dOnXClStX8Oabb6Jfv344fvy4VKuh87hw4QIAIDs7G25ubvD396/Tpnr/7OxshISE1PnukJAQm/952PKcsrOz63yPv78/3NzcmvS8R48ejXvuuQeRkZE4f/48XnnlFdx66604cOAAlEqlQ52fEAKzZ8/GgAEDEBcXJ31vdb2163e0a2jo/ADHv4apqalITExEWVkZvL29sXr1asTExEi/gJrDtTN2joDjX7/ly5fj4MGD2LdvX51tjvzfH4OKAQqFQm9ZCFFnndxGjx4tfe7WrRsSExMRHR2NZcuWSYO/LDmP2m0MtZfzz8NW5yTHed93333S57i4OCQkJCAyMhLr1q3DnXfeaXQ/ezy/adOm4ejRo9i5c2edbc3hGho7P0e/hp07d8bhw4dx/fp1rFy5EpMmTUJycrLR73TEa2fsHGNiYhz6+mVmZmLmzJnYvHkz3N3djbZzxGvIWz81BAUFwdnZuU7iy8nJqZMO7Y2Xlxe6deuG06dPS0//1HceYWFhKC8vR35+fr1trly5Uue7cnNzbf7nYctzCgsLq/M9
+fn5qKiosOl5t2rVCpGRkTh9+rRUlyOc3/Tp07F27Vps27YNbdq0kdY3l2to7PwMcbRr6Obmhg4dOiAhIQFJSUmIj4/HokWLms21q+8cDXGk63fgwAHk5OSgV69ecHFxgYuLC5KTk/Hxxx/DxcVFOq5DXkOzR7U0c3369BFPPfWU3rquXbva3WDa2srKykTr1q3F/PnzpUFTCxculLZrNBqDg6ZWrFghtbl8+bLBQVN79uyR2uzevVvWwbS2OKfqgWCXL1+W2ixfvrzJB9PWdvXqVaFUKsWyZcsc4vx0Op2YOnWqCA8PF6dOnTK43ZGvYUPnZ4ijXcPabr31VjFp0iSHv3amnKMhjnT91Gq1SE1N1ftJSEgQDz30kEhNTXXoa8igUkv148nffPONOHHihJg1a5bw8vIS6enpcpem55lnnhHbt28X586dE7t37xbjxo0TPj4+Up0LFiwQKpVKrFq1SqSmpoqJEycafAytTZs2YuvWreLgwYPi1ltvNfgYWvfu3UVKSopISUkR3bp1a7LHkwsLC8WhQ4fEoUOHBADxwQcfiEOHDkmPhtvqnKofrRs6dKg4ePCg2Lp1q2jTpk2jHx2s7/wKCwvFM888I3bt2iXOnz8vtm3bJhITE0Xr1q0d5vyeeuopoVKpxPbt2/Ue7ywpKZHaOPI1bOj8HP0azp07V+zYsUOcP39eHD16VLz44ovCyclJbN68WQjh2NfOlHN09OtnSM2nfoRw3GvIoGLAZ599JiIjI4Wbm5u46aab9B4/tBfVz7+7urqK8PBwceedd4rjx49L23U6nXjttddEWFiYUCqV4pZbbhGpqal6xygtLRXTpk0TAQEBwsPDQ4wbN05kZGTotcnLyxMPPvig8PHxET4+PuLBBx8U+fn5TXJO27ZtEwDq/FT/344tz+nChQti7NixwsPDQwQEBIhp06aJsrKyJju/kpISMWLECBEcHCxcXV1F27ZtxaRJk+rUbs/nZ+jcAIilS5dKbRz5GjZ0fo5+DR977DHp373g4GAxdOhQKaQI4djXzpRzdPTrZ0jtoOKo11AhhBDm3zAiIiIianocTEtERER2i0GFiIiI7BaDChEREdktBhUiIiKyWwwqREREZLcYVIiIiMhuMagQERGR3WJQISIiIrvFoEJEjRIVFYWPPvrI5Pbbt2+HQqHA9evXm6wmImo+ODMtUQszePBg9OjRw6xwUZ/c3Fx4eXnB09PTpPbl5eW4du0aQkNDLX6lfWNt374dQ4YMQX5+Pvz8/GSpgYhM4yJ3AURkf4QQ0Gq1cHFp+J+I4OBgs47t5uaGsLAwS0sjohaGt36IWpDJkycjOTkZixYtgkKhgEKhQHp6unQ7ZtOmTUhISIBSqcQff/yBs2fPYsKECQgNDYW3tzd69+6NrVu36h2z9q0fhUKBr7/+GnfccQc8PT3RsWNHrF27Vtpe+9bPt99+Cz8/P2zatAldu3aFt7c3Ro0ahaysLGmfyspKzJgxA35+fggMDMQLL7yASZMm4fbbbzd6rhcuXMD48ePh7+8PLy8vxMbGYv369UhPT8eQIUMAAP7+/lAoFJg8eTKAqoD2zjvvoH379vDw8EB8fDx+/vnnOrWvW7cO8fHxcHd3R9++fZGammrhFSGihjCoELUgixYtQmJiIp588klkZWUhKysLERER0vbnn38eSUlJOHnyJLp3746ioiKMGTMGW7duxaFDhzBy5EiMHz8eGRkZ9X7P/Pnzce+99+Lo0aMYM2YMHnzwQVy7ds1o+5KSErz33nv4/vvvsWPHDmRkZODZZ5+Vti9cuBD//ve/sXTpUvz5559Qq9VYs2ZNvTVMnToVGo0GO3bsQGpqKhYuXAhvb29ERERg5cqVAIC0tDRkZWVh0aJFAICXX34ZS5cuxeLFi3H8+HE8/fTTeOihh5CcnKx37Oeeew7vvfce9u3bh5CQENx2222oqKiotx4ispBF71wmIodV+9XvQgixbds2AUCsWbOm
wf1jYmLEJ598Ii1HRkaKDz/8UFoGIF5++WVpuaioSCgUCrFhwwa976p+LfzSpUsFAHHmzBlpn88++0yEhoZKy6GhoeLdd9+VlisrK0Xbtm3FhAkTjNbZrVs3MW/ePIPbatdQXae7u7vYtWuXXtvHH39cTJw4UW+/5cuXS9vz8vKEh4eHWLFihdFaiMhyHKNCRJKEhAS95eLiYsyfPx+//vorLl++jMrKSpSWljbYo9K9e3fps5eXF3x8fJCTk2O0vaenJ6Kjo6XlVq1aSe0LCgpw5coV9OnTR9ru7OyMXr16QafTGT3mjBkz8NRTT2Hz5s0YNmwY7rrrLr26ajtx4gTKysowfPhwvfXl5eXo2bOn3rrExETpc0BAADp37oyTJ08aPTYRWY5BhYgkXl5eesvPPfccNm3ahPfeew8dOnSAh4cH7r77bpSXl9d7HFdXV71lhUJRb6gw1F7UeiCx9hNCtbfX9sQTT2DkyJFYt24dNm/ejKSkJLz//vuYPn26wfbV9a1btw6tW7fW26ZUKuv9LkP1EZF1cIwKUQvj5uYGrVZrUts//vgDkydPxh133IFu3bohLCwM6enpTVtgLSqVCqGhodi7d6+0TqvV4tChQw3uGxERgSlTpmDVqlV45plnsGTJEgBVfwbVx6kWExMDpVKJjIwMdOjQQe+n5jgeANi9e7f0OT8/H6dOnUKXLl0adZ5EZBh7VIhamKioKOzZswfp6enw9vZGQECA0bYdOnTAqlWrMH78eCgUCrzyyiv19ow0lenTpyMpKQkdOnRAly5d8MknnyA/P7/eXoxZs2Zh9OjR6NSpE/Lz8/H777+ja9euAIDIyEgoFAr8+uuvGDNmDDw8PODj44Nnn30WTz/9NHQ6HQYMGAC1Wo1du3bB29sbkyZNko79+uuvIzAwEKGhoXjppZcQFBRU7xNIRGQ59qgQtTDPPvssnJ2dERMTg+Dg4HrHm3z44Yfw9/dHv379MH78eIwcORI33XSTDaut8sILL2DixIl45JFHkJiYCG9vb4wcORLu7u5G99FqtZg6dSq6du2KUaNGoXPnzvj8888BAK1bt8b8+fMxZ84chIaGYtq0aQCAN954A6+++iqSkpLQtWtXjBw5Ev/73//Qrl07vWMvWLAAM2fORK9evZCVlYW1a9dKvTREZF2cmZaIHI5Op0PXrl1x77334o033rDZ93JGWyLb460fIrJ7Fy5cwObNmzFo0CBoNBp8+umnOH/+PB544AG5SyOiJsZbP0Rk95ycnPDtt9+id+/e6N+/P1JTU7F161ZpzAkRNV+89UNERER2iz0qREREZLcYVIiIiMhuMagQERGR3WJQISIiIrvFoEJERER2i0GFiIiI7BaDChEREdktBhUiIiKyW/8P8oi0Mia5KicAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# Train the sequence-to-sequence translation model\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\n",
    "        learning_rate=1e-3):\n",
    "    \"\"\"Train an encoder/decoder pair with teacher forcing, then plot the loss.\n",
    "\n",
    "    Args:\n",
    "        train_data: list of (input_ids, target_ids) pairs; it is shuffled\n",
    "            in place at the start of every epoch.\n",
    "        encoder: module called as encoder(input_tensor) and returning\n",
    "            (outputs, hidden).\n",
    "        decoder: module called as decoder(encoder_outputs, encoder_hidden,\n",
    "            target_tensor); its first output is scored with NLLLoss, so it\n",
    "            is expected to hold log-probabilities.\n",
    "        epochs: number of passes over train_data.\n",
    "        learning_rate: Adam learning rate for both encoder and decoder.\n",
    "    \"\"\"\n",
    "    # One Adam optimizer each for the encoder and decoder parameters\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # Iterate over the `epochs` argument, not the global `n_epochs`\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for input_ids, target_ids in train_data:\n",
    "                # Turn source and target into 1 * seq_len tensors.\n",
    "                # Batch size 1 keeps this simple; larger batches would need\n",
    "                # padding in the encoder (returning the hidden state of the\n",
    "                # last non-padding token) and matching changes in decoding.\n",
    "                input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "                target_tensor = torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Pass the target sequence for teacher-forcing training\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                loss_value = loss.item()\n",
    "                pbar.set_description(f'epoch-{epoch}, loss={loss_value:.4f}')\n",
    "                step_losses.append(loss_value)\n",
    "                # Batch size 1 makes the per-step loss very noisy; averaging\n",
    "                # the last 32 steps gives a smoother curve to inspect\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Model size and optimisation hyper-parameters\n",
    "hidden_size, n_epochs, learning_rate = 128, 20, 1e-3\n",
    "\n",
    "# Encoder over the source vocabulary, decoder over the target vocabulary\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder,\n",
    "    epochs=n_epochs, learning_rate=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索（greedy search）解码，并在随机抽取的样本上查看模型的翻译结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 婚 礼 摄 影 幸 福 攻 略\n",
      "target： wedding photography is a happy attempt .\n",
      "pred： wedding photography is a happy attempt .\n",
      "\n",
      "input： 庖 丁 解 牛 l i n u x 操 作 系 统 分 析\n",
      "target： analysis of the linux operating system for the ding-ding cow\n",
      "pred： analysis of the linux operating system for the ding-ding cow of the linux operating system\n",
      "\n",
      "input： 做 自 己 的 太 阳 无 需 凭 借 谁 的 光\n",
      "target： you don't have to use the light to be your own sun .\n",
      "pred： how to be you don't know market your energy photographer music spectrometers .\n",
      "\n",
      "input： 绿 色 制 造 系 统 集 成 项 目 典 型 案 例\n",
      "target： a typical case of the green manufacturing system integration project\n",
      "pred： the 3d of human resources management and forecasting for the procurement process management .\n",
      "\n",
      "input： 电 商 设 计 技 巧 修 炼 与 实 战 应 用\n",
      "target： vendor design techniques refining and operational applications\n",
      "pred： vendor design techniques refining and operational applications\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    \"\"\"Translate `sentence` by taking the arg-max token at every step.\n",
    "\n",
    "    Returns:\n",
    "        (translation, decoder_attn): the target string decoded up to the\n",
    "        first EOS_token, and the attention output of the decoder.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Source sentence as a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        # No target tensor is given, so the decoder presumably runs on its\n",
    "        # own predictions (assumed from AttnRNNDecoder; confirm there)\n",
    "        decoder_outputs, _, decoder_attn = decoder(encoder_outputs,\n",
    "            encoder_hidden)\n",
    "\n",
    "        # Most probable token id at every step. reshape(-1) instead of\n",
    "        # squeeze() keeps a 1-D sequence even for a single-step output,\n",
    "        # which squeeze() would turn into an un-iterable 0-d tensor.\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        decoded_ids = []\n",
    "        for idx in topi.reshape(-1).tolist():\n",
    "            if idx == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx)\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "# Spot-check greedy decoding on 5 randomly chosen pairs\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    source_text, target_text = pair[0], pair[1]\n",
    "    print('input：', source_text)\n",
    "    print('target：', target_text)\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, source_text,\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "接下来使用束搜索（beam search）解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 体 验 式 营 销 世 界 上 伟 大 品 牌 的 成 功 秘 决 及 营 销 策 略\n",
      "target： empirical marketing , the success of the world's great brands , and marketing strategies .\n",
      "pred： empirical marketing , the success of the world's great brands , and marketing strategies .\n",
      "\n",
      "input： 母 带 处 理 ： 母 带 制 作 技 术 与 艺 术 （ 第 2 版 ）\n",
      "target： material delivery: material production technology and art (version 2)\n",
      "pred： material delivery: material production technology and art (version 2)\n",
      "\n",
      "input： 短 视 频 运 营 ： 从 入 门 到 精 通 （ 微 课 版 ）\n",
      "target： short video operation: from entry to mastery (microtext)\n",
      "pred： short video operation: from entry to mastery (microtext)\n",
      "\n",
      "input： 世 界 绘 画 经 典 教 程 — — 跟 巴 伯 学 素 描 （ 第 2 版 ）\n",
      "target： world painting classic curriculum - with barber's psychiatry . 2nd ed .\n",
      "pred： world painting classic curriculum - with barber's psychiatry . 2nd ed .\n",
      "\n",
      "input： 审 计 学 理 论 案 例 与 实 务\n",
      "target： audit science , theory , case and practice\n",
      "pred： theory , theory , case and practice\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container that keeps the best-scoring candidate hypotheses\n",
    "class BeamHypotheses:\n",
    "    \"\"\"Bounded pool of decoding candidates ranked by normalized score.\n",
    "\n",
    "    Each entry is a tuple (score, hyp, hidden): score is the summed\n",
    "    log-probability divided by max(len(hyp), 1), hyp is the list of token\n",
    "    ids and hidden is the decoder state to continue from.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        self.num_beams = num_beams\n",
    "        self.max_length = max_length\n",
    "        self.beams = []\n",
    "        # Sentinel value used as the worst score while the pool is empty\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        \"\"\"Insert a candidate and keep worst_score up to date.\n",
    "\n",
    "        The candidate is accepted while the pool holds fewer than\n",
    "        num_beams entries, or when it beats the current worst score; in\n",
    "        the latter case the lowest-scoring entry is evicted.\n",
    "        \"\"\"\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if not (len(self) < self.num_beams or score > self.worst_score):\n",
    "            return\n",
    "        self.beams.append((score, hyp, hidden))\n",
    "        if len(self) <= self.num_beams:\n",
    "            self.worst_score = min(score, self.worst_score)\n",
    "            return\n",
    "        # Over capacity: evict the lowest score (earliest entry on ties);\n",
    "        # the second-lowest score becomes the new worst\n",
    "        ranked = sorted(range(len(self.beams)),\n",
    "            key=lambda k: (self.beams[k][0], k))\n",
    "        self.worst_score = self.beams[ranked[1]][0]\n",
    "        del self.beams[ranked[0]]\n",
    "\n",
    "    def pop(self):\n",
    "        \"\"\"Remove and return the first unfinished candidate.\n",
    "\n",
    "        A candidate is unfinished if it is shorter than max_length and is\n",
    "        empty or does not end with EOS_token. Returns (True, entry) on\n",
    "        success and (False, None) when every candidate is finished.\n",
    "        \"\"\"\n",
    "        for pos, entry in enumerate(self.beams):\n",
    "            hyp = entry[1]\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[pos]\n",
    "                # Recompute the worst score over what is left\n",
    "                self.worst_score = min(\n",
    "                    (s for s, _, _ in self.beams), default=1e9)\n",
    "                return True, entry\n",
    "        return False, None\n",
    "\n",
    "    def pop_best(self):\n",
    "        \"\"\"Return (True, entry) for the highest-scoring candidate.\n",
    "\n",
    "        The entry is NOT removed from the pool. Returns (False, None)\n",
    "        when the pool is empty.\n",
    "        \"\"\"\n",
    "        if not self.beams:\n",
    "            return False, None\n",
    "        best = max(range(len(self.beams)),\n",
    "            key=lambda k: (self.beams[k][0], k))\n",
    "        return True, self.beams[best]\n",
    "\n",
    "\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    \"\"\"Translate `sentence` with beam search over decoder.forward_step.\n",
    "\n",
    "    Args:\n",
    "        encoder, decoder: trained seq2seq modules; the decoder is stepped\n",
    "            via forward_step(decoder_input, hidden, encoder_outputs).\n",
    "        sentence: source sentence, converted with input_lang.sent2ids.\n",
    "        input_lang, output_lang: source and target vocabularies.\n",
    "        num_beams: pool size and number of children expanded per\n",
    "            candidate; 1 is supported.\n",
    "\n",
    "    Returns:\n",
    "        The best-scoring translation as a string with a trailing\n",
    "        EOS_token removed, or '' when no candidate is available.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Source sentence as a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Seed the pool with an empty hypothesis that starts from the\n",
    "        # encoder's hidden state\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, [], encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Expand one unfinished candidate at a time until none is left\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "            score, hyp, decoder_hidden = item\n",
    "\n",
    "            # Decoder input: the last generated token, or SOS_token at start\n",
    "            last_token = hyp[-1] if hyp else SOS_token\n",
    "            decoder_input = torch.full((1, 1), last_token, dtype=torch.long)\n",
    "\n",
    "            # Decode one step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # Top-k next tokens. reshape(-1) instead of squeeze() keeps a\n",
    "            # 1-D result, so num_beams=1 no longer yields a 0-d tensor\n",
    "            # that zip() cannot iterate over.\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # Stored scores are length-normalized; multiply back to a sum\n",
    "            prefix_logprob = score * len(hyp)\n",
    "            for logp, token_id in zip(topk_values.reshape(-1).tolist(),\n",
    "                    topk_ids.reshape(-1).tolist()):\n",
    "                hypotheses.add(prefix_logprob + logp, hyp + [token_id],\n",
    "                    decoder_hidden)\n",
    "\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if not flag:\n",
    "            return ''\n",
    "        best_hyp = item[1]\n",
    "        # Strip a trailing EOS_token without mutating the stored hypothesis\n",
    "        if best_hyp and best_hyp[-1] == EOS_token:\n",
    "            best_hyp = best_hyp[:-1]\n",
    "        return output_lang.ids2sent(best_hyp)\n",
    "\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "# Spot-check beam-search decoding on 5 randomly chosen pairs\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    source_text, target_text = pair[0], pair[1]\n",
    "    print('input：', source_text)\n",
    "    print('target：', target_text)\n",
    "    output_sentence = beam_search_decode(encoder, decoder, source_text,\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
